{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  MSD 歌曲推荐——协同过滤推荐预处理\n",
    "1. 隐式播放次数 --> 显式打分\n",
    "2. 训练数据/测试数据分割\n",
    "3. 对训练数据，建立倒排表，比实时查询数据库快\n",
    "4. 计算每个用户的平均打分\n",
    "5. 对训练数据，预计算所有用户之间的相似度，保存用户相似度矩阵\n",
    "6. 对训练数据，预计算所有物品之间的相似度，保存物品相似度矩阵\n",
    "7. 用训练数据训练SVD模型，保存SVD模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "#import sys\n",
    "#reload (sys)\n",
    "#sys.setdefaultencoding(\"utf-8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-09-26T05:34:54.624167Z",
     "start_time": "2017-09-26T05:34:46.420964Z"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from collections import defaultdict\n",
    "import scipy.sparse as ss\n",
    "\n",
    "#保存数据\n",
    "import pickle as cPickle\n",
    "import scipy.io as sio\n",
    "\n",
    "#距离\n",
    "import scipy.spatial.distance as ssd\n",
    "\n",
    "from numpy.random import random\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据\n",
    "用户(800)、歌曲（800）及其播放次数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that holds the input CSV; later cells also write their splits here.\n",
    "dpath = './data/'\n",
    "\n",
    "# One row per (user, song) pair together with its play count.\n",
    "df_triplet = pd.read_csv(filepath_or_buffer=dpath + 'triplet_dataset_sub.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCKSGZ12A58A7CA4B</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCVTLJ12A6310F0FD</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SODLLYS12A8C13A96B</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOEGIYH12A6D4FC0E3</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOFRQTD12A81C233C0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOHEMBB12A6701E907</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOHJOLH12A6310DFE5</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOIZLKI12A6D4F7B61</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOJGSIO12A8C141DBF</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOKEYJQ12A6D4F6132</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       user                song  play_count\n",
       "0  4e11f45d732f4861772b2906f81a7d384552ad12  SOCKSGZ12A58A7CA4B           1\n",
       "1  4e11f45d732f4861772b2906f81a7d384552ad12  SOCVTLJ12A6310F0FD           1\n",
       "2  4e11f45d732f4861772b2906f81a7d384552ad12  SODLLYS12A8C13A96B           3\n",
       "3  4e11f45d732f4861772b2906f81a7d384552ad12  SOEGIYH12A6D4FC0E3           1\n",
       "4  4e11f45d732f4861772b2906f81a7d384552ad12  SOFRQTD12A81C233C0           2\n",
       "5  4e11f45d732f4861772b2906f81a7d384552ad12  SOHEMBB12A6701E907           1\n",
       "6  4e11f45d732f4861772b2906f81a7d384552ad12  SOHJOLH12A6310DFE5           1\n",
       "7  4e11f45d732f4861772b2906f81a7d384552ad12  SOIZLKI12A6D4F7B61           1\n",
       "8  4e11f45d732f4861772b2906f81a7d384552ad12  SOJGSIO12A8C141DBF           1\n",
       "9  4e11f45d732f4861772b2906f81a7d384552ad12  SOKEYJQ12A6D4F6132           1"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_triplet.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(37519, 3)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_triplet.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 打分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of plays per user, as a two-column frame (user, total_play_count).\n",
    "user_totals = (\n",
    "    df_triplet.groupby('user')['play_count']\n",
    "    .sum()\n",
    "    .reset_index()\n",
    "    .rename(columns={'play_count': 'total_play_count'})\n",
    ")\n",
    "\n",
    "# Attach each user's total to their rows, then express every play count as a\n",
    "# fraction of that total; this fraction is the score used in the rest of the notebook.\n",
    "df_triplet = pd.merge(df_triplet, user_totals)\n",
    "df_triplet['fractional_play_count'] = df_triplet['play_count'].div(df_triplet['total_play_count'])\n",
    "del user_totals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user</th>\n",
       "      <th>song</th>\n",
       "      <th>play_count</th>\n",
       "      <th>total_play_count</th>\n",
       "      <th>fractional_play_count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCKSGZ12A58A7CA4B</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOCVTLJ12A6310F0FD</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SODLLYS12A8C13A96B</td>\n",
       "      <td>3</td>\n",
       "      <td>259</td>\n",
       "      <td>0.011583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOEGIYH12A6D4FC0E3</td>\n",
       "      <td>1</td>\n",
       "      <td>259</td>\n",
       "      <td>0.003861</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4e11f45d732f4861772b2906f81a7d384552ad12</td>\n",
       "      <td>SOFRQTD12A81C233C0</td>\n",
       "      <td>2</td>\n",
       "      <td>259</td>\n",
       "      <td>0.007722</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       user                song  play_count  \\\n",
       "0  4e11f45d732f4861772b2906f81a7d384552ad12  SOCKSGZ12A58A7CA4B           1   \n",
       "1  4e11f45d732f4861772b2906f81a7d384552ad12  SOCVTLJ12A6310F0FD           1   \n",
       "2  4e11f45d732f4861772b2906f81a7d384552ad12  SODLLYS12A8C13A96B           3   \n",
       "3  4e11f45d732f4861772b2906f81a7d384552ad12  SOEGIYH12A6D4FC0E3           1   \n",
       "4  4e11f45d732f4861772b2906f81a7d384552ad12  SOFRQTD12A81C233C0           2   \n",
       "\n",
       "   total_play_count  fractional_play_count  \n",
       "0               259               0.003861  \n",
       "1               259               0.003861  \n",
       "2               259               0.011583  \n",
       "3               259               0.003861  \n",
       "4               259               0.007722  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_triplet.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 训练数据/测试数据分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Split the row labels rather than the rows themselves: 60% train / 40% test,\n",
    "# with a fixed random_state so the split is repeatable.\n",
    "total_index = df_triplet.index\n",
    "train_index, test_index = train_test_split(\n",
    "    total_index,\n",
    "    test_size=0.4,\n",
    "    random_state=7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the two subsets from the split indices.\n",
    "df_triplet_train, df_triplet_test = (\n",
    "    df_triplet.iloc[train_index],\n",
    "    df_triplet.iloc[test_index],\n",
    ")\n",
    "\n",
    "# Persist both subsets next to the source data (the row index is written as the first column).\n",
    "df_triplet_train.to_csv(dpath + 'triplet_dataset_sub_train.csv')\n",
    "df_triplet_test.to_csv(dpath + 'triplet_dataset_sub_test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 对训练数据，事先计算好倒排表，比实时查询数据库快\n",
    "用户和item重新建索引"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of Users :781\n",
      "number of Songs :800\n"
     ]
    }
   ],
   "source": [
    "# Distinct users and songs that occur in the training split.\n",
    "users = list(df_triplet_train['user'].unique())\n",
    "items = list(df_triplet_train['song'].unique())\n",
    "n_users = len(users)\n",
    "n_items = len(items)\n",
    "\n",
    "print(\"number of Users :%d\" % n_users)\n",
    "print(\"number of Songs :%d\" % n_items)\n",
    "\n",
    "# Raw user / song id -> contiguous integer index.\n",
    "users_index = {u: i for i, u in enumerate(users)}\n",
    "items_index = {e: i for i, e in enumerate(items)}\n",
    "\n",
    "# Inverted tables: songs played by each user / users who played each song.\n",
    "user_items = defaultdict(set)\n",
    "item_users = defaultdict(set)\n",
    "\n",
    "# Sparse user-item score matrix, filled one entry at a time below.\n",
    "user_item_scores = ss.dok_matrix((n_users, n_items))\n",
    "\n",
    "n_records = df_triplet_train.shape[0]\n",
    "\n",
    "# Walk the three columns in lockstep. Calling .iloc[i][col] three times per record\n",
    "# builds a fresh row Series on every access and dominates the run time.\n",
    "for user, song, score in zip(df_triplet_train['user'],\n",
    "                             df_triplet_train['song'],\n",
    "                             df_triplet_train['fractional_play_count']):\n",
    "    user_index_i = users_index[user]\n",
    "    item_index_i = items_index[song]\n",
    "\n",
    "    user_items[user_index_i].add(item_index_i)    # songs of this user\n",
    "    item_users[item_index_i].add(user_index_i)    # users of this song\n",
    "\n",
    "    # Score = fraction of the user's total plays that went to this song.\n",
    "    user_item_scores[user_index_i, item_index_i] = score\n",
    "\n",
    "# Save the inverted tables. The with-blocks make sure every file is flushed and\n",
    "# closed before later cells (or other notebooks) read it.\n",
    "with open(\"user_items.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(user_items, fh)\n",
    "with open(\"item_users.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(item_users, fh)\n",
    "\n",
    "# Save the user-item score matrix in Matrix Market format for later use.\n",
    "sio.mmwrite(\"user_item_scores\", user_item_scores)\n",
    "\n",
    "# Save the id -> index lookup tables.\n",
    "with open(\"users_index.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(users_index, fh)\n",
    "with open(\"items_index.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(items_index, fh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 计算每个用户的平均打分 和所有用户的平均打分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean score of every user over the songs they played in the training split.\n",
    "users_mu = np.zeros(n_users)\n",
    "for u in range(n_users):\n",
    "    rated_items = user_items[u]\n",
    "    # Every indexed user comes from a training record, so rated_items is never empty.\n",
    "    users_mu[u] = sum(user_item_scores[u, i] for i in rated_items) / len(rated_items)\n",
    "\n",
    "# The with-blocks flush and close each pickle file as soon as it is written.\n",
    "with open(\"users_mu.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(users_mu, fh)\n",
    "\n",
    "# Global mean score over all training records.\n",
    "mu = df_triplet_train['fractional_play_count'].mean()\n",
    "with open(\"mu.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(mu, fh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.  预先计算好所有用户之间的相似度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1.1 计算两个用户之间的相似度\n",
    "以播放比例为特征"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_similarity_playcount(uid1, uid2):\n",
    "    \"\"\"Cosine similarity of two users' mean-centred scores on the songs both played.\n",
    "\n",
    "    uid1 and uid2 are integer user indices into user_items, user_item_scores\n",
    "    and users_mu. Returns 0.0 when the users share no song, or when the cosine\n",
    "    is undefined because one centred vector is all zeros.\n",
    "    \"\"\"\n",
    "    # Songs scored by both users, in uid1's iteration order.\n",
    "    shared = [item for item in user_items[uid1] if item in user_items[uid2]]\n",
    "\n",
    "    if len(shared) == 0:\n",
    "        return 0.0\n",
    "\n",
    "    # Each user's scores on the shared songs, minus that user's own mean score.\n",
    "    centred_1 = np.array([user_item_scores[uid1, item] - users_mu[uid1] for item in shared])\n",
    "    centred_2 = np.array([user_item_scores[uid2, item] - users_mu[uid2] for item in shared])\n",
    "\n",
    "    # ssd.cosine is a distance, so 1 - distance is the similarity.\n",
    "    similarity = 1 - ssd.cosine(centred_1, centred_2)\n",
    "\n",
    "    # NaN means one of the centred vectors has zero L2 norm.\n",
    "    return 0.0 if np.isnan(similarity) else similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1.2 计算两个用户之间的相似度\n",
    "以是否播放过歌曲为特征"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_similarity_played(uid1, uid2 ):\n",
    "    \"\"\"Jaccard index of the song sets played by two users.\n",
    "\n",
    "    uid1 and uid2 are integer user indices into user_items. Returns the int 0\n",
    "    when the users have no song in common, otherwise |intersection| / |union|\n",
    "    as a float.\n",
    "    \"\"\"\n",
    "    songs_1 = user_items[uid1]   # songs played by uid1\n",
    "    songs_2 = user_items[uid2]   # songs played by uid2\n",
    "\n",
    "    n_shared = len(songs_1 & songs_2)\n",
    "    if n_shared == 0:\n",
    "        return 0\n",
    "\n",
    "    # The union is only built when there is something in common.\n",
    "    return float(n_shared) / float(len(songs_1 | songs_2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 计算好所有用户之间的相似性\n",
    "对用户比较少、用户比较固定的系统适用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ui=0 \n",
      "ui=100 \n",
      "ui=200 \n",
      "ui=300 \n",
      "ui=400 \n",
      "ui=500 \n",
      "ui=600 \n",
      "ui=700 \n"
     ]
    }
   ],
   "source": [
    "# Symmetric user-user similarity matrix with ones on the diagonal.\n",
    "users_similarity_matrix = np.matrix(np.zeros(shape=(n_users, n_users)), float)\n",
    "\n",
    "for ui in range(n_users):\n",
    "    users_similarity_matrix[ui,ui] = 1.0\n",
    "\n",
    "    # Progress indicator every 100 users.\n",
    "    if(ui % 100 == 0):\n",
    "        print (\"ui=%d \" % (ui))\n",
    "\n",
    "    # Only the pairs above the diagonal are computed; each value is mirrored.\n",
    "    for uj in range(ui+1,n_users):\n",
    "        pair_similarity = user_similarity_played(ui, uj)\n",
    "        users_similarity_matrix[uj,ui] = pair_similarity\n",
    "        users_similarity_matrix[ui,uj] = pair_similarity\n",
    "\n",
    "# The with-block flushes and closes the pickle file once the dump is complete.\n",
    "with open(\"users_similarity_played.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(users_similarity_matrix, fh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. 事先计算好所有item之间的相似性"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1.1 计算两个item之间的相似度\n",
    "以播放次数/播放比例为特征"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def item_similarity_playcount(iid1, iid2):\n",
    "    \"\"\"Cosine similarity of two songs over the users who scored both.\n",
    "\n",
    "    iid1 and iid2 are integer song indices into item_users and\n",
    "    user_item_scores. Each score is centred by the scoring user's mean\n",
    "    (users_mu) before the cosine is taken. Returns the int 0 when no user\n",
    "    scored both songs, and 0.0 when the cosine is undefined because one\n",
    "    centred vector is all zeros.\n",
    "    \"\"\"\n",
    "    # Users who scored both songs, in iid1's iteration order.\n",
    "    shared_users = [user for user in item_users[iid1] if user in item_users[iid2]]\n",
    "\n",
    "    if len(shared_users) == 0:\n",
    "        return 0\n",
    "\n",
    "    # Scores of each song from the shared users, minus the respective user's mean.\n",
    "    centred_1 = np.array([user_item_scores[user, iid1] - users_mu[user] for user in shared_users])\n",
    "    centred_2 = np.array([user_item_scores[user, iid2] - users_mu[user] for user in shared_users])\n",
    "\n",
    "    # ssd.cosine is a distance, so 1 - distance is the similarity.\n",
    "    similarity = 1 - ssd.cosine(centred_1, centred_2)\n",
    "\n",
    "    # NaN means the denominator was zero (one centred vector is all zeros).\n",
    "    return 0.0 if np.isnan(similarity) else similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1.2 计算两个item之间的相似度\n",
    "以是否播放为特征\n",
    "比以播放次数为特征计算快"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def item_similarity_played(iid1, iid2 ):\n",
    "    \"\"\"Jaccard index of the user sets that played two songs.\n",
    "\n",
    "    iid1 and iid2 are integer song indices into item_users. Returns the int 0\n",
    "    when no user played both songs, otherwise |intersection| / |union| as a\n",
    "    float.\n",
    "    \"\"\"\n",
    "    listeners_1 = item_users[iid1]   # users who played iid1\n",
    "    listeners_2 = item_users[iid2]   # users who played iid2\n",
    "\n",
    "    n_shared = len(listeners_1 & listeners_2)\n",
    "    if n_shared == 0:\n",
    "        return 0\n",
    "\n",
    "    # The union is only built when at least one user played both songs.\n",
    "    return float(n_shared) / float(len(listeners_1 | listeners_2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 计算所有item之间的相似性\n",
    "对item比较少、Item比较固定的系统适用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i=0 \n",
      "i=100 \n",
      "i=200 \n",
      "i=300 \n",
      "i=400 \n",
      "i=500 \n",
      "i=600 \n",
      "i=700 \n"
     ]
    }
   ],
   "source": [
    "# Symmetric item-item similarity matrix with ones on the diagonal.\n",
    "items_similarity_matrix = np.matrix(np.zeros(shape=(n_items, n_items)), float)\n",
    "\n",
    "for i in range(n_items):\n",
    "    items_similarity_matrix[i,i] = 1.0\n",
    "\n",
    "    # Progress indicator every 100 items.\n",
    "    if(i % 100 == 0):\n",
    "        print (\"i=%d \" % (i) )\n",
    "\n",
    "    # Only the pairs above the diagonal are computed; each value is mirrored.\n",
    "    for j in range(i+1,n_items):\n",
    "        pair_similarity = item_similarity_played(i, j)\n",
    "        items_similarity_matrix[j,i] = pair_similarity\n",
    "        items_similarity_matrix[i,j] = pair_similarity\n",
    "\n",
    "# The with-block flushes and closes the pickle file once the dump is complete.\n",
    "with open(\"items_similarity_played.pkl\", 'wb') as fh:\n",
    "    cPickle.dump(items_similarity_matrix, fh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 7. SVD模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.1 模型初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensionality of the latent factors.\n",
    "K = 40\n",
    "\n",
    "# Seed NumPy's global RNG so the random initialisation below (and the sample\n",
    "# shuffling during training) is repeatable after Restart & Run All.\n",
    "np.random.seed(7)\n",
    "\n",
    "# Bias terms for items and users, one row each.\n",
    "bi = np.zeros((n_items,1))\n",
    "bu = np.zeros((n_users,1))\n",
    "\n",
    "# Latent factors for items and users, drawn uniformly from [0, 0.1/sqrt(K)).\n",
    "# NOTE: the previous expression `random(...)/10*(np.sqrt(K))` multiplied by\n",
    "# sqrt(K) because of operator precedence, which made the initial dot products\n",
    "# far larger than the fractional play counts (all <= 1) being predicted.\n",
    "init_scale = 10 * np.sqrt(K)\n",
    "pu = random((n_users, K)) / init_scale\n",
    "qi = random((n_items, K)) / init_scale"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 根据当前参数，预测用户uid对Item（iid）的打分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def svd_pred(uid, iid):\n",
    "    \"\"\"Predict the score of user index uid for item index iid with the current parameters.\n",
    "\n",
    "    The prediction is the global mean mu plus the item bias, the user bias and\n",
    "    the dot product of the two latent vectors. With the (n, 1) bias arrays\n",
    "    created in the initialisation cell the result is a length-1 array.\n",
    "    \"\"\"\n",
    "    interaction = np.sum(qi[iid] * pu[uid])\n",
    "\n",
    "    # Scores are deliberately left unclipped (a clamp to the 1-5 range is disabled).\n",
    "    return mu + bi[iid] + bu[uid] + interaction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.3 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The 0-th  step is running\n",
      "the rmse of this step on train data is  [1.01084031]\n",
      "The 1-th  step is running\n",
      "the rmse of this step on train data is  [0.17780972]\n",
      "The 2-th  step is running\n",
      "the rmse of this step on train data is  [0.11156777]\n",
      "The 3-th  step is running\n",
      "the rmse of this step on train data is  [0.08873554]\n",
      "The 4-th  step is running\n",
      "the rmse of this step on train data is  [0.07794535]\n",
      "The 5-th  step is running\n",
      "the rmse of this step on train data is  [0.07146943]\n",
      "The 6-th  step is running\n",
      "the rmse of this step on train data is  [0.06713242]\n",
      "The 7-th  step is running\n",
      "the rmse of this step on train data is  [0.06366775]\n",
      "The 8-th  step is running\n",
      "the rmse of this step on train data is  [0.06088709]\n",
      "The 9-th  step is running\n",
      "the rmse of this step on train data is  [0.0587144]\n",
      "The 10-th  step is running\n",
      "the rmse of this step on train data is  [0.05674141]\n",
      "The 11-th  step is running\n",
      "the rmse of this step on train data is  [0.05530473]\n",
      "The 12-th  step is running\n",
      "the rmse of this step on train data is  [0.05392047]\n",
      "The 13-th  step is running\n",
      "the rmse of this step on train data is  [0.05256819]\n",
      "The 14-th  step is running\n",
      "the rmse of this step on train data is  [0.05159431]\n",
      "The 15-th  step is running\n",
      "the rmse of this step on train data is  [0.0506909]\n",
      "The 16-th  step is running\n",
      "the rmse of this step on train data is  [0.0498129]\n",
      "The 17-th  step is running\n",
      "the rmse of this step on train data is  [0.04910679]\n",
      "The 18-th  step is running\n",
      "the rmse of this step on train data is  [0.04844261]\n",
      "The 19-th  step is running\n",
      "the rmse of this step on train data is  [0.04789528]\n",
      "The 20-th  step is running\n",
      "the rmse of this step on train data is  [0.04743657]\n",
      "The 21-th  step is running\n",
      "the rmse of this step on train data is  [0.04699682]\n",
      "The 22-th  step is running\n",
      "the rmse of this step on train data is  [0.04655829]\n",
      "The 23-th  step is running\n",
      "the rmse of this step on train data is  [0.04620861]\n",
      "The 24-th  step is running\n",
      "the rmse of this step on train data is  [0.04589235]\n",
      "The 25-th  step is running\n",
      "the rmse of this step on train data is  [0.04558881]\n",
      "The 26-th  step is running\n",
      "the rmse of this step on train data is  [0.04531016]\n",
      "The 27-th  step is running\n",
      "the rmse of this step on train data is  [0.04506096]\n",
      "The 28-th  step is running\n",
      "the rmse of this step on train data is  [0.04485065]\n",
      "The 29-th  step is running\n",
      "the rmse of this step on train data is  [0.04465486]\n",
      "The 30-th  step is running\n",
      "the rmse of this step on train data is  [0.04449292]\n",
      "The 31-th  step is running\n",
      "the rmse of this step on train data is  [0.04431982]\n",
      "The 32-th  step is running\n",
      "the rmse of this step on train data is  [0.04417022]\n",
      "The 33-th  step is running\n",
      "the rmse of this step on train data is  [0.04402918]\n",
      "The 34-th  step is running\n",
      "the rmse of this step on train data is  [0.04391072]\n",
      "The 35-th  step is running\n",
      "the rmse of this step on train data is  [0.0437883]\n",
      "The 36-th  step is running\n",
      "the rmse of this step on train data is  [0.04369374]\n",
      "The 37-th  step is running\n",
      "the rmse of this step on train data is  [0.0435992]\n",
      "The 38-th  step is running\n",
      "the rmse of this step on train data is  [0.04349806]\n",
      "The 39-th  step is running\n",
      "the rmse of this step on train data is  [0.04341909]\n",
      "The 40-th  step is running\n",
      "the rmse of this step on train data is  [0.04334683]\n",
      "The 41-th  step is running\n",
      "the rmse of this step on train data is  [0.04329585]\n",
      "The 42-th  step is running\n",
      "the rmse of this step on train data is  [0.04322163]\n",
      "The 43-th  step is running\n",
      "the rmse of this step on train data is  [0.0431703]\n",
      "The 44-th  step is running\n",
      "the rmse of this step on train data is  [0.04311091]\n",
      "The 45-th  step is running\n",
      "the rmse of this step on train data is  [0.04305925]\n",
      "The 46-th  step is running\n",
      "the rmse of this step on train data is  [0.0430167]\n",
      "The 47-th  step is running\n",
      "the rmse of this step on train data is  [0.04297466]\n",
      "The 48-th  step is running\n",
      "the rmse of this step on train data is  [0.04293484]\n",
      "The 49-th  step is running\n",
      "the rmse of this step on train data is  [0.04289949]\n"
     ]
    }
   ],
   "source": [
    "# gamma:  learning rate\n",
    "# Lambda: regularisation strength\n",
    "# steps:  number of passes over the training data\n",
    "\n",
    "steps=50\n",
    "gamma=0.04\n",
    "Lambda=0.15\n",
    "\n",
    "# Total number of training records.\n",
    "n_records = df_triplet_train.shape[0]\n",
    "\n",
    "# Resolve ids to integer indices once, instead of doing three .iloc row\n",
    "# lookups per sample in every epoch.\n",
    "train_uids = df_triplet_train['user'].map(users_index).values\n",
    "train_iids = df_triplet_train['song'].map(items_index).values\n",
    "train_ratings = df_triplet_train['fractional_play_count'].values\n",
    "\n",
    "for step in range(steps):\n",
    "    print ('The ' + str(step) + '-th  step is running' )\n",
    "    rmse_sum=0.0\n",
    "\n",
    "    # Visit the training samples in a fresh random order each epoch.\n",
    "    kk = np.random.permutation(n_records)\n",
    "    for line in kk:\n",
    "        # One training sample per update.\n",
    "        uid = train_uids[line]\n",
    "        iid = train_iids[line]\n",
    "        rating = train_ratings[line]\n",
    "\n",
    "        # Prediction residual.\n",
    "        eui = rating - svd_pred(uid, iid)\n",
    "        # Running sum of squared residuals.\n",
    "        rmse_sum += eui**2\n",
    "\n",
    "        # Stochastic gradient descent updates.\n",
    "        bu[uid] += gamma * (eui - Lambda * bu[uid])\n",
    "        bi[iid] += gamma * (eui - Lambda * bi[iid])\n",
    "\n",
    "        # qi[iid] is a view into qi, so it must be copied: otherwise the pu\n",
    "        # update below would see the already-updated item vector.\n",
    "        temp = qi[iid].copy()\n",
    "        qi[iid] += gamma * (eui* pu[uid]- Lambda*qi[iid] )\n",
    "        pu[uid] += gamma * (eui* temp - Lambda*pu[uid])\n",
    "\n",
    "    # Decay the learning rate after every epoch.\n",
    "    gamma=gamma*0.93\n",
    "    print (\"the rmse of this step on train data is \",np.sqrt(rmse_sum/n_records))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.4 保存模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained SVD parameters as a JSON document.\n",
    "import json\n",
    "def save_json(filepath):\n",
    "    \"\"\"Write mu, K, the bias arrays and the latent-factor arrays to filepath as JSON.\n",
    "\n",
    "    The arrays are converted to nested lists first so that json can encode them.\n",
    "    \"\"\"\n",
    "    payload = {\n",
    "        'mu': mu,\n",
    "        'K': K,\n",
    "        'bi': bi.tolist(),\n",
    "        'bu': bu.tolist(),\n",
    "        'qi': qi.tolist(),\n",
    "        'pu': pu.tolist(),\n",
    "    }\n",
    "\n",
    "    # Serialise first, then write the text in one go.\n",
    "    with open(filepath, 'w') as out:\n",
    "        out.write(json.dumps(payload))\n",
    "save_json('svd_model.json')"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "nav_menu": {
    "height": "153px",
    "width": "160px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": true,
   "toc_position": {
    "height": "691px",
    "left": "0px",
    "right": "1405px",
    "top": "106px",
    "width": "303px"
   },
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
